from implement.variables.parameter import Parameter


class FreezeParam:
    def __init__(self, *layers):
        """
        冻结指定层的参数，使其在梯度更新时不被修改。

        Args:
            *layers: 传入要冻结参数的层或参数。
        """
        self.freeze_params = []
        for l in layers:
            if isinstance(l, Parameter):
                self.freeze_params.append(l)
            else:
                for p in l.params():
                    self.freeze_params.append(p)

    def __call__(self, params):
        """
        将指定层的参数梯度置为None，以实现参数的冻结。

        Args:
            params (list): 参数列表。

        Returns:
            None
        """
        for p in self.freeze_params:
            p.grad = None


